{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import functools\n",
    "import json\n",
    "import math\n",
    "import pickle\n",
    "\n",
    "from os.path import isfile\n",
    "from os.path import join as pjoin\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "from time import time\n",
    "\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "from transformers import AutoModel, AutoTokenizer, BertConfig\n",
    "\n",
    "from utils_parsing import *\n",
    "from utils_caip import *\n",
    "\n",
    "from train_model import *\n",
    "\n",
    "from pprint import pprint\n",
    "from random import choice\n",
    "from time import time\n",
    "\n",
    "# Checkpoint prefix: model_name + '_args.pk' is the pickled args object, model_name + '.pth' the state dict.\n",
    "model_name = '../data/ttad_transformer_model/caip_model_dir/caip_test_model'\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load an args file you trust.\n",
    "with open(model_name + '_args.pk', 'rb') as args_file:\n",
    "    args = pickle.load(args_file)\n",
    "\n",
    "# Override the stored data directory with the path used by this notebook.\n",
    "args.data_dir = '../data/ttad_transformer_model/annotated_data/'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(args.pretrained_encoder_name)\n",
    "# Use a context manager so the vocabulary file handle is closed right after parsing.\n",
    "with open(args.tree_voc_file) as voc_file:\n",
    "    full_tree, tree_i2w = json.load(voc_file)\n",
    "dataset = CAIPDataset(tokenizer, args, prefix=\"\", full_tree_voc=(full_tree, tree_i2w))\n",
    "\n",
    "enc_model = AutoModel.from_pretrained(args.pretrained_encoder_name)\n",
    "bert_config = BertConfig.from_pretrained(\"bert-base-uncased\")\n",
    "bert_config.is_decoder = True\n",
    "# + 8: presumably extra ids on top of the tree vocabulary -- TODO confirm against train_model.\n",
    "bert_config.vocab_size = len(tree_i2w) + 8\n",
    "bert_config.num_hidden_layers = args.num_decoder_layers\n",
    "dec_with_loss = DecoderWithLoss(bert_config, args, tokenizer)\n",
    "encoder_decoder = EncoderDecoderWithLoss(enc_model, dec_with_loss, args)\n",
    "# Load the weights onto CPU first so restoring does not depend on the device they were saved from.\n",
    "encoder_decoder.load_state_dict(torch.load(model_name + '.pth', map_location='cpu'))\n",
    "\n",
    "encoder_decoder = encoder_decoder.cuda()\n",
    "# No assignment to `_` here: that would shadow IPython's last-output variable.\n",
    "encoder_decoder.eval()\n",
    "\n",
    "def get_beam_tree(chat, noop_thres=0.95, beam_size=5, well_formed_pen=1e2):\n",
    "    \"\"\"Run beam search on `chat` and return the selected parse tree.\n",
    "\n",
    "    The top candidate is returned unless its 'dialogue_type' is 'NOOP' and\n",
    "    exp(score) is below `noop_thres`; in that case the second candidate is\n",
    "    returned when the beam contains one.\n",
    "\n",
    "    Args:\n",
    "        chat: input sentence, passed to beam_search unchanged.\n",
    "        noop_thres: minimum exp(score) required to accept a NOOP top candidate.\n",
    "        beam_size: forwarded to beam_search.\n",
    "        well_formed_pen: forwarded to beam_search.\n",
    "\n",
    "    Returns:\n",
    "        The first element of the selected beam entry.\n",
    "    \"\"\"\n",
    "    btr = beam_search(chat, encoder_decoder, tokenizer, dataset, beam_size, well_formed_pen)\n",
    "    top_tree, top_score = btr[0][0], btr[0][1]\n",
    "    low_confidence_noop = (\n",
    "        top_tree.get('dialogue_type', 'NONE') == 'NOOP' and math.exp(top_score) < noop_thres\n",
    "    )\n",
    "    # Only fall back to the runner-up when one exists; with a single-candidate\n",
    "    # beam (e.g. beam_size=1) indexing btr[1] would raise IndexError.\n",
    "    if low_confidence_noop and len(btr) > 1:\n",
    "        return btr[1][0]\n",
    "    return top_tree\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dialogue_type': 'HUMAN_GIVE_COMMAND',\n",
       " 'action_sequence': [{'action_type': 'MOVE'}]}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one-word commands\n",
    "get_beam_tree(\"come\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dialogue_type': 'HUMAN_GIVE_COMMAND',\n",
       " 'action_sequence': [{'action_type': 'BUILD',\n",
       "   'location': {'location_type': 'REFERENCE_OBJECT',\n",
       "    'relative_direction': 'BETWEEN',\n",
       "    'reference_object': {'has_name': [0, [5, 5]]}},\n",
       "   'schematic': {'has_name': [0, [2, 2]]}}]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# with a small typo\n",
    "get_beam_tree(\"buildz an elevator between the mountains\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dialogue_type': 'GET_MEMORY'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# example error: the top-ranked parse is a bare GET_MEMORY with no size, colour or reference object\n",
    "get_beam_tree(\"how big is the blue house?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dialogue_type': 'GET_MEMORY'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# rephrasing the question the way it is worded in the templates\n",
    "get_beam_tree(\"what is the size of this blue house?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[({'dialogue_type': 'GET_MEMORY'}, -0.20432472229003906),\n",
       " ({'dialogue_type': 'NOOP'}, -6.626201152801514),\n",
       " ({'dialogue_type': 'GET_MEMORY',\n",
       "   'has_size': [0, [1, 1]],\n",
       "   'has_colour': [0, [4, 4]],\n",
       "   'reference_object': {'has_name': [0, [5, 5]]},\n",
       "   'repeat': {'stop_condition': {'condition_type': 'NEVER'}}},\n",
       "  -318.8774108886719),\n",
       " ({'dialogue_type': 'GET_MEMORY',\n",
       "   'has_size': [0, [1, 1]],\n",
       "   'has_colour': [0, [4, 4]],\n",
       "   'reference_object': {'has_name': [0, [5, 5]]},\n",
       "   'repeat': {'stop_condition': {'condition_type': 'NEVER'}}},\n",
       "  -325.29071044921875),\n",
       " ({'dialogue_type': 'GET_MEMORY',\n",
       "   'has_size': [0, [1, 1]],\n",
       "   'has_colour': [0, [4, 4]],\n",
       "   'reference_object': {'has_name': [0, [5, 5]]},\n",
       "   'repeat': {'stop_condition': {'condition_type': 'NEVER'}}},\n",
       "  -418.1733703613281)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the full beam (tree, score) shows that the right interpretation is among the candidates\n",
    "[\n",
    "    (tree, score)\n",
    "    for tree, score, _ in beam_search(\"how big is the blue house?\", encoder_decoder, tokenizer, dataset)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
